In [8]:
# !pip install kaleido numpy pandas neurokit2 plotly seaborn ts2vg pyxdf

Imports and helper functions¶

In [9]:
import json

import numpy as np
import pandas as pd
import neurokit2 as nk
import plotly.io as pio
import plotly.express as px
from _plotly_utils.colors import n_colors
from matplotlib import pyplot as plt
import plotly.graph_objs as go
import warnings
import glob
from operator import itemgetter
import os

from util import plot_data, plot_channels, plot_gantt, markers_to_gantt, plot_epoch, plot_correlation_matrix, \
    seconds_to_samples, samples_to_seconds, g
from IPython.display import display

pio.renderers.default = 'notebook_connected+jupyterlab'
In [10]:
# Adapted from Neurokit's read_xdf() to work with the data from this experiment
# https://neuropsychology.github.io/NeuroKit/functions/data.html#neurokit2.data.read_xdf
def read_xdf(subject, upsample=2, fillmissing=None):
    """**Read and tidy an XDF file**

    Load the subject's XDF recording, extract the marker stream, trim all
    data streams to the window spanned by the markers, merge them on a
    common time index, and resample/interpolate to a single sampling rate.

    Parameters
    ----------
    subject : str
        Subject identifier; passed to ``g()`` to locate the ``*.xdf`` and
        ``*meta.json`` files.
    upsample : int or float, optional
        The target sampling rate is ``max(nominal rates) * upsample``.
    fillmissing : float or None, optional
        Maximum gap (in seconds) to fill by interpolation. ``None``
        interpolates across gaps of any length.

    Returns
    -------
    streams_df : pd.DataFrame
        Merged, evenly resampled signal data with a datetime index relative
        to the start of the analysis window.
    markers_df : pd.DataFrame
        Marker labels with a datetime index on the same time base.
    info : dict
        Sampling-rate metadata, recording datetime, the per-stream
        dataframes, and the subject id.
    """

    def get_markers(markers_stream):
        """Return the expected markers (in recording order) and their timestamps."""
        markers = markers_stream['time_series']
        assert all(len(marker) == 1 for marker in markers), 'Warning: There is an event containing more than one marker'
        markers = [marker[0] for marker in markers]

        # Counterbalancing: the meta file records which design was shown first
        designs = ["fair", "dark"]
        with open(g(subject, '*meta.json')[0], 'r') as meta_file:
            dark_first = json.load(meta_file)['darkFirst']
        if dark_first:
            designs = designs[::-1]
        events = [
            "app/start", "cookies/start", "cookies/end",
            "geolocation/start", "geolocation/end",
            "notification/start", "notification/end",
            "travelProtection/start", "travelProtection/end",
            "newsletter/start", "newsletter/end", "app/end"
        ]
        expected_marker_order = [f"{d}/{e}" for d in designs for e in events]

        # Keep only the first occurrence of each expected marker, scanning
        # forward so the kept markers appear in the expected order. Markers
        # missing from the recording are skipped silently.
        relevant_indices = []
        previous_index = 0
        for marker in expected_marker_order:
            index = next((i for i in range(previous_index, len(markers)) if marker in markers[i]), None)
            if index is not None:
                relevant_indices.append(index)
                previous_index = index

        timestamps = itemgetter(*relevant_indices)(markers_stream['time_stamps'])
        markers = itemgetter(*relevant_indices)(markers)
        return markers, timestamps

    try:
        import pyxdf
    except ImportError:
        # Pass a single string: passing two arguments makes the message
        # render as a tuple.
        raise ImportError(
            "The 'pyxdf' module is required for this function to run. "
            "Please install it first (`pip install pyxdf`)."
        )

    # Load file
    print(f"Reading xdf file for subject: {subject}")
    streams, header = pyxdf.load_xdf(g(subject, '*.xdf')[0])

    # Remove any empty streams
    streams = [stream for stream in streams if len(stream['time_series'])]

    # The markers stream is the only one whose time_series is a plain list
    markers_stream = next(filter(lambda stream: isinstance(stream['time_series'], list), streams))
    streams.remove(markers_stream)
    markers, timestamps = get_markers(markers_stream)
    # itemgetter returns a tuple (or a scalar for a single index); coerce to
    # an array so the offset arithmetic below is well defined
    timestamps = np.atleast_1d(np.asarray(timestamps))

    # Get the time range for analysis (from first to last marker)
    min_marker_time = min(timestamps)
    max_marker_time = max(timestamps)

    # Find the actual data range across all streams
    all_stream_times = []
    for stream in streams:
        all_stream_times.extend(stream["time_stamps"])

    data_start_time = min(all_stream_times)
    data_end_time = max(all_stream_times)

    # Use the overlap between data and markers as the analysis window
    analysis_start = max(data_start_time, min_marker_time)
    analysis_end = min(data_end_time, max_marker_time)

    # All timestamps below are expressed relative to the analysis start
    offset = analysis_start

    markers_df = pd.DataFrame(markers, columns=['marker'])
    # Keep only the "<design>/<event>" portion of the marker path
    markers_df["marker"] = markers_df["marker"].str.split("/").str[3:6].apply("/".join)
    markers_df.index = pd.to_datetime(timestamps - offset, unit="s")

    # Process other streams and convert to dataframes
    dfs = []
    for stream in streams:
        time_mask = (stream["time_stamps"] >= analysis_start) & (stream["time_stamps"] <= analysis_end)

        if not np.any(time_mask):
            print("Warning: No data in analysis window for stream")
            continue

        filtered_timestamps = stream["time_stamps"][time_mask]
        filtered_data = stream["time_series"][time_mask]

        print(f"Stream data after filtering: {len(filtered_timestamps)} samples")

        # Channel labels come from the stream's XML metadata
        channels_info = stream["info"]["desc"][0]["channels"][0]["channel"]
        cols = [channels_info[i]["label"][0] for i in range(len(channels_info))]
        dat = pd.DataFrame(filtered_data, columns=cols)

        # Apply offset to timestamps
        dat.index = pd.to_datetime(filtered_timestamps - offset, unit="s")
        dfs.append(dat)

    if not dfs:
        raise ValueError("No valid data found in analysis window")

    # Store info of each stream
    info = {
        "sampling_rates_original": [float(s["info"]["nominal_srate"][0]) for s in streams],
        "sampling_rates_effective": [float(s["info"]["effective_srate"]) for s in streams],
        "datetime": header["info"]["datetime"][0],
        "data": dfs,
        "subject": subject,
    }

    # Merge all dataframes by timestamps
    streams_df = dfs[0]
    for i in range(1, len(dfs)):
        streams_df = pd.merge(streams_df, dfs[i], how="outer", left_index=True, right_index=True)
    streams_df = streams_df.sort_index()

    # Resample and interpolate to a common rate
    info["sampling_rate"] = int(np.max(info["sampling_rates_original"]) * upsample)
    print(f"Target sampling rate: {info['sampling_rate']} Hz")

    if fillmissing is not None:
        # Convert the gap limit from seconds to samples at the target rate
        fillmissing = int(info["sampling_rate"] * fillmissing)

    # Create new index with evenly spaced timestamps
    idx = pd.date_range(
        streams_df.index.min(),
        streams_df.index.max(),
        freq=str(1000 / info["sampling_rate"]) + "ms"
    )

    # Reindex to the union first so time-based interpolation can use the
    # original sample positions, then restrict to the even grid
    streams_df = streams_df.reindex(streams_df.index.union(idx))

    # Only interpolate numeric columns
    numeric_cols = streams_df.select_dtypes(include=[np.number]).columns
    streams_df[numeric_cols] = streams_df[numeric_cols].interpolate(method="time", limit=fillmissing)

    # Use the new evenly spaced index
    streams_df = streams_df.reindex(idx)

    return streams_df, markers_df, info
In [11]:
def get_signals_and_events(streams_df, markers_df, info, plot=True):
    """Clean the streams, run EDA processing, and build an events dict.

    Parameters
    ----------
    streams_df : pd.DataFrame
        Merged signal data as returned by ``read_xdf``.
    markers_df : pd.DataFrame
        Marker labels indexed by datetime.
    info : dict
        Metadata from ``read_xdf``; only ``info["sampling_rate"]`` is used.
    plot : bool, optional
        If True, show the NeuroKit EDA summary plot and the processed
        channel/gantt plots.

    Returns
    -------
    signals : pd.DataFrame
        Processed EDA signals.
    events : dict
        Onsets, durations, labels, conditions, designs and event names for
        event-related analysis.
    """
    # Rename channels for clarity and keep only the ones we analyse
    rename_columns = {
        'RAW0': 'EDA',
    }
    columns_to_keep = [*rename_columns.values()]

    streams_df = streams_df.rename(columns=rename_columns)
    streams_df = streams_df[columns_to_keep]

    # Process EDA
    eda_signals, eda_info = nk.eda_process(streams_df['EDA'], sampling_rate=info["sampling_rate"],
                                           method_cleaning="biosppy")
    if plot:
        nk.eda_plot(eda_signals, eda_info)

    # Concatenate processed signals (single modality for now; extend here if
    # more biosignals are added)
    signals = pd.concat([eda_signals], axis=1)

    # Reindex markers with numeric index for further processing
    nearest_indices = streams_df.index.get_indexer(markers_df.index, method='nearest')
    markers_numindexed = markers_df.copy()
    markers_numindexed.index = nearest_indices  # Use integer sample numbers

    if plot:
        # Plot processed signals (only some columns)
        columns_to_plot = ['EDA_Tonic', 'EDA_Phasic']
        plot_channels(signals[columns_to_plot], markers_numindexed, title='Processed Channel Data',
                      hide_end_markers=True)
        plot_gantt(markers_df)

    # Drop the app start/end markers; the pattern is a regex for str.contains
    markers_to_remove = '/app'
    markers_numindexed = markers_numindexed[~markers_numindexed.marker.str.contains(markers_to_remove)]

    # Create events dictionary from gantt data for event-related analysis
    gantt_data = markers_to_gantt(markers_numindexed)

    labels = [d['marker'] for d in gantt_data]

    # Each marker label is "<design>/<event>"
    designs, event_names = zip(*[d['marker'].split('/')[:2] for d in gantt_data])

    events = dict(
        onset=[d['start'] for d in gantt_data],
        duration=[d['duration'] for d in gantt_data],
        label=labels,
        condition=labels,
        designs=designs,
        event_names=event_names
    )

    return signals, events
In [12]:
def get_epoch_features(signals, events, info, plot=True):
    """Build epochs around each event and extract features.

    Parameters
    ----------
    signals : pd.DataFrame
        Processed signals from ``get_signals_and_events``.
    events : dict
        Event dictionary with ``onset``, ``duration``, ``label``,
        ``event_names`` and ``designs``.
    info : dict
        Metadata; only ``info["sampling_rate"]`` is used.
    plot : bool, optional
        If True, plot every epoch.

    Returns
    -------
    epochs : dict
        Epoch dataframes keyed by event label.
    epoch_features : pd.DataFrame
        Event-related features per epoch plus duration and extra EDA stats.
    non_epoch_features : pd.DataFrame
        Interval-related features of the signal outside all epochs.
    """
    # Build epochs from 1 s before to 5 s after each event onset
    epochs = nk.epochs_create(signals, events, sampling_rate=info["sampling_rate"], epochs_start=-1, epochs_end=5)

    # Collect the sample indices covered by any epoch so the remaining
    # ("non-epoch") signal can be analysed separately
    epoch_onsets = events['onset']
    epoch_lengths = [len(e) for e in epochs.values()]
    epoch_indices = []
    for start, length in zip(epoch_onsets, epoch_lengths):
        epoch_indices.extend(range(start, start + length))
    non_epoch_signals = signals.drop(index=epoch_indices)

    if plot:
        for epoch in epochs.values():
            plot_epoch(epoch, subplots=True)
            fig = plt.gcf()
            fig.suptitle(f"Epoch from -1 to 5 seconds for Event: {epoch['Condition'].values[0]}")
            fig.show()

    # Analyze epochs and extract features
    bio_epoch_features = nk.bio_analyze(epochs, sampling_rate=info["sampling_rate"])

    epoch_features = pd.concat([bio_epoch_features], axis=1)

    # Calculate some additional features per event
    for event in epochs.keys():
        epoch_features.loc[event, 'total_duration'] = samples_to_seconds(
            events['duration'][events['label'].index(event)], info["sampling_rate"])
        epoch_features.loc[event, 'EDA_Phasic_Mean'] = epochs[event]['EDA_Phasic'].mean()
        epoch_features.loc[event, 'EDA_Tonic_Max'] = epochs[event]['EDA_Tonic'].max()

    epoch_features['event_name'] = list(events['event_names'])
    epoch_features['design'] = list(events['designs'])
    epoch_features = epoch_features.drop(columns=["Event_Onset", "Label", "Condition"])

    # Do interval analysis to get features about the non-epoch signals
    non_epoch_features = nk.bio_analyze(non_epoch_signals, sampling_rate=info["sampling_rate"], method="interval")

    return epochs, epoch_features, non_epoch_features

Load and process all data¶

In [13]:
# Process all subjects' raw data (or load the cached CSVs if present).
if not all(os.path.isfile(fname) for fname in
           ['non_epoch_features.csv', 'epoch_features.csv', 'epoch_signals.csv'
            # , 'metadata.csv'
            ]):
    # Processed data hasn't been saved yet, so process raw data
    non_epoch_features_dfs = []
    dfs = []  # Epoch features
    epoch_signal_dfs = []
    meta_dfs = []
    # Skip hidden entries; [1:] additionally skips the first data directory
    for i, subject in enumerate(sorted([d for d in os.listdir('data') if not d.startswith('.')])[1:], start=1):
        # Read and clean subject's XDF file and extract streams and markers as dataframes
        streams_df, markers_df, info = read_xdf(subject, upsample=2)

        # Further process the streams and markers
        signals, events = get_signals_and_events(streams_df, markers_df, info, plot=True)

        # Do event-related analysis to get features about epochs surrounding the events
        epochs, epoch_features, non_epoch_features = get_epoch_features(signals, events, info, plot=False)
        non_epoch_features.index = [i]

        epoch_signal_df = pd.concat(epochs.values())
        epoch_signal_df.insert(0, 'subject', i)

        # Every epoch has the same length, so repeat each event's name/design
        # once per sample of its epoch
        epoch_length = len(epochs[next(iter(epochs))])
        epoch_signal_df['event_name'] = [name for name in events['event_names'] for _ in range(epoch_length)]
        epoch_signal_df['design'] = [design for design in events['designs'] for _ in range(epoch_length)]

        # Load subject's questionnaire answers
        # part1_answers, iuipc_score = load_questionnaire(subject)
        df = (epoch_features
              # .join(part1_answers)
              )
        df.insert(0, 'subject', i)

        epoch_signal_dfs.append(epoch_signal_df)

        # Load subject's demographics
        # meta_df = load_demographics(subject)
        # meta_df.index = [i]
        # meta_df['iuipc'] = iuipc_score

        non_epoch_features_dfs.append(non_epoch_features)
        # BUG FIX: without this append (and the concat below), only the last
        # subject's epoch features were written to epoch_features.csv
        dfs.append(df)
        # meta_dfs.append(meta_df)

    non_epoch_df = pd.concat(non_epoch_features_dfs)
    non_epoch_df.to_csv('non_epoch_features.csv')

    df = pd.concat(dfs)
    df.to_csv('epoch_features.csv')

    epoch_signal_df = pd.concat(epoch_signal_dfs)
    epoch_signal_df.to_csv('epoch_signals.csv')
    #
    # meta_df = pd.concat(meta_dfs)
    # meta_df.to_csv('metadata.csv')
else:
    # Processed data has already been saved, so load it
    df = pd.read_csv('epoch_features.csv', index_col=0)
    # non_epoch_df = pd.read_csv('non_epoch_features.csv', index_col=0)
    epoch_signal_df = pd.read_csv('epoch_signals.csv', index_col=0)
    # meta_df = pd.read_csv('metadata.csv', index_col=0)
Reading xdf file for subject: 002
Stream data after filtering: 377640 samples
Target sampling rate: 2000 Hz
No description has been provided for this image

Data Analysis¶

In [14]:
# NOTE: plt and warnings are already imported in the imports cell at the top
# of the notebook, so the duplicate imports were removed here.

columns_to_plot = ['EDA_Tonic', 'EDA_Phasic']
sampling_rate = 2000  # must match the target rate reported by read_xdf

unique_events = epoch_signal_df["event_name"].unique()
designs = ["fair", "dark"]

# One row per event, one column per design
fig, axes = plt.subplots(len(unique_events), 2, figsize=(10, 3 * len(unique_events)))

for row_idx, event_name in enumerate(unique_events):
    for col_idx, design in enumerate(designs):
        label = f"{design}_{event_name}"

        # Summary statistics across subjects for this event/design
        rows_matching = df[(df['event_name'] == event_name) & (df['design'] == design)]
        print(f"Event: {label}")
        print(f"Mean phasic EDA: {rows_matching['EDA_Phasic_Mean'].mean():.3f}")
        print(f"Max. tonic EDA: {rows_matching['EDA_Tonic_Max'].max():.3f}")
        print()

        # Filter signal rows
        signal_rows = epoch_signal_df[
            (epoch_signal_df["event_name"] == event_name) &
            (epoch_signal_df["design"] == design)
            ][["subject", *columns_to_plot]]

        if signal_rows.empty:
            print(f"Warning: No data available for {label}. Skipping plot.")
            axes[row_idx, col_idx].text(0.5, 0.5, f"No data for {label}",
                                        horizontalalignment='center',
                                        verticalalignment='center',
                                        transform=axes[row_idx, col_idx].transAxes)
            axes[row_idx, col_idx].set_xlabel("Seconds")
            continue

        # Average the epochs sample-wise across subjects, then convert the
        # sample index to seconds relative to the event onset (epochs start
        # 1 s before the onset, hence the -1)
        grouped_by_subject = signal_rows.groupby("subject")
        per_subject_epochs = [group.reset_index(drop=True) for _, group in grouped_by_subject]
        averaged_epoch = pd.concat(per_subject_epochs).groupby(level=0).mean()
        averaged_epoch.index = [samples_to_seconds(i, sampling_rate) - 1 for i in averaged_epoch.index]

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            ax = axes[row_idx, col_idx]
            ax2 = ax.twinx()

            # Plot tonic on left axis, phasic on right
            averaged_epoch['EDA_Tonic'].plot(ax=ax, color='tab:blue', label='EDA_Tonic')
            averaged_epoch['EDA_Phasic'].plot(ax=ax2, color='tab:orange', label='EDA_Phasic')

            # Axis styling
            ax.set_ylabel("Tonic EDA", color='tab:blue')
            ax2.set_ylabel("Phasic EDA", color='tab:orange')

            # Mark the event onset
            ax.axvline(x=0, color="grey", linestyle="--")
            ax.set_title(f"{event_name} ({design})")
            ax.set_xlabel("Seconds")

plt.tight_layout()
plt.show()
Event: fair_cookies
Mean phasic EDA: -3.283
Max. tonic EDA: 180.238

Event: dark_cookies
Mean phasic EDA: -5.361
Max. tonic EDA: 279.273

Event: fair_geolocation
Mean phasic EDA: 5.552
Max. tonic EDA: 186.961

Event: dark_geolocation
Mean phasic EDA: 0.668
Max. tonic EDA: 296.842

Event: fair_notification
Mean phasic EDA: -0.078
Max. tonic EDA: 193.945

Event: dark_notification
Mean phasic EDA: -3.897
Max. tonic EDA: 311.563

Event: fair_travelProtection
Mean phasic EDA: 1.966
Max. tonic EDA: 259.060

Event: dark_travelProtection
Mean phasic EDA: 10.267
Max. tonic EDA: 336.185

Event: fair_newsletter
Mean phasic EDA: 0.305
Max. tonic EDA: 264.092

Event: dark_newsletter
Mean phasic EDA: -2.563
Max. tonic EDA: 329.672

No description has been provided for this image
In [ ]: